{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Training the network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "\n",
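    "Now we will see how to train the network. The `DQN` class below combines the primary and target dueling Q-networks, the experience replay buffer and the training step."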
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class DQN(object):\n",
    "    \"\"\"Double DQN agent built on dueling Q-networks (pre-1.0 TensorFlow API).\n",
    "\n",
    "    Holds a primary network (`qnet`), a target network (`target_qnet`), an\n",
    "    RMSProp optimizer and an experience replay buffer. Actions are chosen with\n",
    "    a linearly decaying epsilon-greedy policy.\n",
    "\n",
    "    Relies on `tf`, `np`, `random`, `QNetworkDueling` and `ReplayMemoryFast`\n",
    "    being defined/imported in earlier cells.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, state_size,\n",
    "                 action_size,\n",
    "                 session,\n",
    "                 summary_writer=None,\n",
    "                 exploration_period=1000,\n",
    "                 minibatch_size=32,\n",
    "                 discount_factor=0.99,\n",
    "                 experience_replay_buffer=10000,\n",
    "                 target_qnet_update_frequency=10000,\n",
    "                 initial_exploration_epsilon=1.0,\n",
    "                 final_exploration_epsilon=0.05,\n",
    "                 reward_clipping=-1):\n",
    "        \"\"\"Store hyper-parameters, build both networks and the training graph.\n",
    "\n",
    "        Args:\n",
    "            state_size: tuple shape of one state; a leading batch dimension\n",
    "                is prepended when the placeholders are created.\n",
    "            action_size: number of discrete actions.\n",
    "            session: TensorFlow session used to run the graph.\n",
    "            summary_writer: optional summary writer; when None, train()\n",
    "                skips writing summaries.\n",
    "            exploration_period: training steps over which epsilon decays\n",
    "                linearly from the initial to the final value.\n",
    "            minibatch_size: passed to ReplayMemoryFast as the sample size.\n",
    "            discount_factor: discount applied to the next-state value.\n",
    "            experience_replay_buffer: capacity passed to ReplayMemoryFast.\n",
    "            target_qnet_update_frequency: training steps between copies of\n",
    "                the primary weights into the target network.\n",
    "            initial_exploration_epsilon: epsilon at training step 0.\n",
    "            final_exploration_epsilon: epsilon after exploration_period steps.\n",
    "            reward_clipping: if > 0, stored rewards are clipped to\n",
    "                [-reward_clipping, reward_clipping]; the default -1 disables it.\n",
    "        \"\"\"\n",
    "        self.state_size = state_size\n",
    "        self.action_size = action_size\n",
    "\n",
    "        self.session = session\n",
    "        self.exploration_period = float(exploration_period)\n",
    "        self.minibatch_size = minibatch_size\n",
    "        self.discount_factor = tf.constant(discount_factor)\n",
    "        self.experience_replay_buffer = experience_replay_buffer\n",
    "        self.summary_writer = summary_writer\n",
    "        self.reward_clipping = reward_clipping\n",
    "\n",
    "        self.target_qnet_update_frequency = target_qnet_update_frequency\n",
    "        self.initial_exploration_epsilon = initial_exploration_epsilon\n",
    "        self.final_exploration_epsilon = final_exploration_epsilon\n",
    "        self.num_training_steps = 0\n",
    "\n",
    "        # Primary network: picks actions and is trained every step.\n",
    "        self.qnet = QNetworkDueling(self.state_size, self.action_size, \"qnet\")\n",
    "\n",
    "        # Target network: periodically synchronised copy of the primary network.\n",
    "        self.target_qnet = QNetworkDueling(self.state_size, self.action_size, \"target_qnet\")\n",
    "\n",
    "        # RMSProp optimizer for the primary network.\n",
    "        self.qnet_optimizer = tf.train.RMSPropOptimizer(learning_rate=0.00025, decay=0.99, epsilon=0.01)\n",
    "\n",
    "        self.experience_replay = ReplayMemoryFast(self.experience_replay_buffer, self.minibatch_size)\n",
    "\n",
    "        # Build all placeholders, losses and update ops.\n",
    "        self.create_graph()\n",
    "\n",
    "    @staticmethod\n",
    "    def copy_to_target_network(source_network, target_network):\n",
    "        \"\"\"Return one grouped op assigning every source variable to its target.\n",
    "\n",
    "        Declared as a staticmethod so that `DQN.copy_to_target_network(a, b)`\n",
    "        also works on Python 2, where calling a plain function through the\n",
    "        class requires a DQN instance as the first argument.\n",
    "\n",
    "        Variables are paired positionally via zip(), so both networks are\n",
    "        assumed to list their variables in the same order (TODO confirm in\n",
    "        QNetworkDueling.variables()).\n",
    "        \"\"\"\n",
    "        target_network_update = []\n",
    "        for v_source, v_target in zip(source_network.variables(), target_network.variables()):\n",
    "            # equivalent to target = source\n",
    "            target_network_update.append(v_target.assign(v_source))\n",
    "        return tf.group(*target_network_update)\n",
    "\n",
    "    def create_graph(self):\n",
    "        \"\"\"Build action selection, Double-DQN targets, the Huber loss and\n",
    "        optimizer step, the target-network copy op and the summary op.\"\"\"\n",
    "\n",
    "        # Greedy action selection from the primary network.\n",
    "        with tf.name_scope(\"pick_action\"):\n",
    "            # Input batch of states.\n",
    "            self.state = tf.placeholder(tf.float32, (None,) + self.state_size, name=\"state\")\n",
    "\n",
    "            # Primary-network Q-values for every state in the batch.\n",
    "            self.q_values = tf.identity(self.qnet(self.state), name=\"q_values\")\n",
    "\n",
    "            # Index of the highest Q-value per state.\n",
    "            self.predicted_actions = tf.argmax(self.q_values, dimension=1, name=\"predicted_actions\")\n",
    "\n",
    "            # Track the batch mean of the max Q-value to monitor learning.\n",
    "            tf.histogram_summary(\"Q values\", tf.reduce_mean(tf.reduce_max(self.q_values, 1)))\n",
    "\n",
    "        # Double-DQN target: the primary network chooses the next action and\n",
    "        # the target network evaluates it.\n",
    "        with tf.name_scope(\"estimating_future_rewards\"):\n",
    "            self.next_state = tf.placeholder(tf.float32, (None,) + self.state_size, name=\"next_state\")\n",
    "            self.next_state_mask = tf.placeholder(tf.float32, (None,), name=\"next_state_mask\")  # 0 for terminal states\n",
    "            self.rewards = tf.placeholder(tf.float32, (None,), name=\"rewards\")\n",
    "\n",
    "            self.next_q_values_targetqnet = tf.stop_gradient(self.target_qnet(self.next_state), name=\"next_q_values_targetqnet\")\n",
    "\n",
    "            self.next_q_values_qnet = tf.stop_gradient(self.qnet(self.next_state), name=\"next_q_values_qnet\")\n",
    "            self.next_selected_actions = tf.argmax(self.next_q_values_qnet, dimension=1)\n",
    "            self.next_selected_actions_onehot = tf.one_hot(indices=self.next_selected_actions, depth=self.action_size)\n",
    "\n",
    "            # Target-network value of the primary network's chosen action;\n",
    "            # the mask zeroes it out for terminal transitions.\n",
    "            self.next_max_q_values = tf.stop_gradient(\n",
    "                tf.reduce_sum(tf.mul(self.next_q_values_targetqnet, self.next_selected_actions_onehot),\n",
    "                              reduction_indices=[1, ]) * self.next_state_mask)\n",
    "\n",
    "            self.target_q_values = self.rewards + self.discount_factor * self.next_max_q_values\n",
    "\n",
    "        # Loss on the TD error and gradient step for the primary network.\n",
    "        with tf.name_scope(\"optimization_step\"):\n",
    "            # One-hot mask of the action actually taken in each transition.\n",
    "            self.action_mask = tf.placeholder(tf.float32, (None, self.action_size), name=\"action_mask\")\n",
    "            self.y = tf.reduce_sum(self.q_values * self.action_mask, reduction_indices=[1, ])\n",
    "\n",
    "            # Huber loss (delta = 1): quadratic for |error| <= 1, linear beyond.\n",
    "            self.error = tf.abs(self.y - self.target_q_values)\n",
    "            quadratic_part = tf.clip_by_value(self.error, 0.0, 1.0)\n",
    "            linear_part = self.error - quadratic_part\n",
    "            self.loss = tf.reduce_mean(0.5 * tf.square(quadratic_part) + linear_part)\n",
    "\n",
    "            # Only primary-network variables are optimised; each gradient is\n",
    "            # norm-clipped to 10.\n",
    "            qnet_gradients = self.qnet_optimizer.compute_gradients(self.loss, self.qnet.variables())\n",
    "            for i, (grad, var) in enumerate(qnet_gradients):\n",
    "                if grad is not None:\n",
    "                    qnet_gradients[i] = (tf.clip_by_norm(grad, 10), var)\n",
    "\n",
    "            self.qnet_optimize = self.qnet_optimizer.apply_gradients(qnet_gradients)\n",
    "\n",
    "        # Op copying the primary-network weights into the target network.\n",
    "        with tf.name_scope(\"target_network_update\"):\n",
    "            self.hard_copy_to_target = DQN.copy_to_target_network(self.qnet, self.target_qnet)\n",
    "\n",
    "        # train() runs this op when logging, so it must exist. Uses the same\n",
    "        # pre-1.0 summary API as tf.histogram_summary above.\n",
    "        self.summarize = tf.merge_all_summaries()\n",
    "\n",
    "    def store(self, state, action, reward, next_state, is_terminal):\n",
    "        \"\"\"Add one transition to the replay buffer, clipping the reward to\n",
    "        [-reward_clipping, reward_clipping] when reward_clipping > 0.\"\"\"\n",
    "        if self.reward_clipping > 0.0:\n",
    "            reward = np.clip(reward, -self.reward_clipping, self.reward_clipping)\n",
    "\n",
    "        self.experience_replay.store(state, action, reward, next_state, is_terminal)\n",
    "\n",
    "    def action(self, state, training=False):\n",
    "        \"\"\"Pick an action for a single state with an epsilon-greedy policy.\n",
    "\n",
    "        With training=True, epsilon decays linearly from\n",
    "        initial_exploration_epsilon to final_exploration_epsilon over\n",
    "        exploration_period training steps and then stays at the final value.\n",
    "        With training=False a fixed epsilon of 0.05 is used.\n",
    "        \"\"\"\n",
    "        if self.num_training_steps > self.exploration_period:\n",
    "            epsilon = self.final_exploration_epsilon\n",
    "        else:\n",
    "            epsilon = self.initial_exploration_epsilon - float(self.num_training_steps) * (self.initial_exploration_epsilon - self.final_exploration_epsilon) / self.exploration_period\n",
    "\n",
    "        if not training:\n",
    "            epsilon = 0.05\n",
    "\n",
    "        # Random action with probability epsilon, otherwise the greedy one.\n",
    "        if random.random() <= epsilon:\n",
    "            action = random.randint(0, self.action_size - 1)\n",
    "        else:\n",
    "            action = self.session.run(self.predicted_actions, {self.state: [state]})[0]\n",
    "\n",
    "        return action\n",
    "\n",
    "    def train(self):\n",
    "        \"\"\"Run one optimisation step on a minibatch from the replay memory.\n",
    "\n",
    "        Returns immediately if the sample is empty. Before the first step the\n",
    "        target network is synchronised with the primary network; afterwards\n",
    "        it is re-synchronised (and logs written) every\n",
    "        target_qnet_update_frequency steps.\n",
    "        \"\"\"\n",
    "        # Sample first so nothing (including the start-up message) repeats\n",
    "        # while the replay memory is still filling up.\n",
    "        minibatch = self.experience_replay.sample()\n",
    "        if len(minibatch) == 0:\n",
    "            return\n",
    "\n",
    "        if self.num_training_steps == 0:\n",
    "            print(\"Training starts...\")\n",
    "            # Reuse the copy op built in create_graph.\n",
    "            self.session.run(self.hard_copy_to_target)\n",
    "\n",
    "        # Unpack the transitions; sizes follow the actual sample length so a\n",
    "        # short sample cannot cause an IndexError.\n",
    "        batch_size = len(minibatch)\n",
    "        batch_states = np.asarray([d[0] for d in minibatch])\n",
    "        batch_actions = np.zeros((batch_size, self.action_size))\n",
    "        batch_actions[np.arange(batch_size), [d[1] for d in minibatch]] = 1\n",
    "        batch_rewards = np.asarray([d[2] for d in minibatch])\n",
    "        batch_newstates = np.asarray([d[3] for d in minibatch])\n",
    "        # 1.0 for non-terminal next states, 0.0 for terminal ones.\n",
    "        batch_newstates_mask = np.asarray([not d[4] for d in minibatch], dtype=np.float32)\n",
    "\n",
    "        feed = {self.state: batch_states,\n",
    "                self.next_state: batch_newstates,\n",
    "                self.next_state_mask: batch_newstates_mask,\n",
    "                self.rewards: batch_rewards,\n",
    "                self.action_mask: batch_actions}\n",
    "\n",
    "        scores, _ = self.session.run([self.q_values, self.qnet_optimize], feed)\n",
    "\n",
    "        if self.num_training_steps % self.target_qnet_update_frequency == 0:\n",
    "            self.session.run(self.hard_copy_to_target)\n",
    "\n",
    "            print(\"mean maxQ in minibatch: %s\" % np.mean(np.max(scores, 1)))\n",
    "\n",
    "            # summary_writer defaults to None; only log when one was given.\n",
    "            if self.summary_writer is not None:\n",
    "                summary_str = self.session.run(self.summarize, feed)\n",
    "                self.summary_writer.add_summary(summary_str, self.num_training_steps)\n",
    "\n",
    "        self.num_training_steps += 1\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:anaconda]",
   "language": "python",
   "name": "conda-env-anaconda-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
